# Task that trains ResNet-50 (`--model resnet50`) with the `rand-m9-mstd0.5-inc1`
# auto-augment policy using the timm training script; see `run` below.
name: resnet50-randaug

# Candidate hardware configurations: one entry per accelerator type.
# Both request a count of 4, which matches `--nproc_per_node=4` in `run`.
resources:
  candidates:
    - accelerators: T4:4
    - accelerators: V100:4

# Directory holding the files this task needs. `setup` refers to
# `../callback.patch` and `../dummy_dataset.patch` from inside the `timm`
# checkout, so both patches are presumably expected here (verify).
workdir: ./examples/benchmark/timm

setup: |
  # Setup can be executed more than once against the same machine, so each
  # step below is written to succeed when its result is already in place.

  # Reuse the conda env when it exists; create it only on the first run.
  conda activate timm 2>/dev/null || {
    conda create -n timm python=3.8 -y
    conda activate timm
  }

  # Install SkyCallback
  pip install "git+https://github.com/skypilot-org/skypilot.git#egg=sky-callback&subdirectory=sky/callbacks/"

  # User setup: clone timm unless a checkout is already present, then pin it
  # to the release the patches below were written against.
  if [ ! -d timm/.git ]; then
    git clone https://github.com/rwightman/pytorch-image-models.git timm
  fi
  cd timm
  git checkout v0.5.4
  pip install -r requirements.txt

  # Apply a patch only if the working tree does not already contain it:
  # when the reverse of the patch applies cleanly, it was applied earlier.
  apply_patch_once() {
    git apply --reverse --check "$1" 2>/dev/null || git apply "$1"
  }

  # Apply the patch to enable SkyCallback
  apply_patch_once ../callback.patch

  # Apply the patch to use a dummy ImageNet dataset to avoid data downloading
  apply_patch_once ../dummy_dataset.patch

# Training command. Activates the `timm` conda env and enters the `timm`
# checkout, both created by `setup`, then starts `train.py` through
# `torch.distributed.launch` with 4 processes per node. That count is
# hard-coded and matches the accelerator count of every entry under
# `resources.candidates` (presumably one process per accelerator); keep the
# two in sync when changing either.
run: |
  conda activate timm
  cd timm
  python3 -m torch.distributed.launch --nproc_per_node=4 train.py \
    -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 \
    --amp --remode pixel --reprob 0.6 --aug-splits 3 \
    --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd \
    --dist-bn reduce
